# runners/sq_hypersense_runner.py
from __future__ import annotations
from typing import Dict, Any
import time

import ConfigSpace as CS
from optuna.distributions import FloatDistribution, IntDistribution, CategoricalDistribution

from objective import Objective
from loggers import ExperimentLogger
from runners.random_runner import _canonicalize

# HyperSense components
from hypersense.pipeline import HyperSensePipeline
from hypersense.optimizer.optuna_optimizer import OptunaOptimizer
from hypersense.sampler.stratified_sampler import StratifiedSampler
from hypersense.importance.n_rrelieff import NRReliefFAnalyzer
from hypersense.strategy.sequential_grouping import SequentialGroupingStrategy


def _cs_to_hypersense_space(cs: CS.ConfigurationSpace):
    """
    Convert a ConfigSpace search space into the Optuna-style space HyperSense expects.

    Hyperparameters that can only take a single value (``Constant``, or a
    categorical/ordinal with exactly one option) are not searched. Instead they
    are returned in ``fixed_config`` so the caller can merge them back into every
    evaluated configuration.

    Args:
        cs: ConfigSpace configuration space to convert.

    Returns:
        (space_dict, fixed_config):
            ``space_dict`` maps hyperparameter names to Optuna
            ``CategoricalDistribution`` / ``IntDistribution`` /
            ``FloatDistribution`` objects.
            ``fixed_config`` maps names of single-valued hyperparameters to
            their only value.

    Raises:
        NotImplementedError: For hyperparameter types with no mapping here.
    """
    space: Dict[str, Any] = {}
    fixed: Dict[str, Any] = {}

    from ConfigSpace.hyperparameters import (
        Constant, CategoricalHyperparameter, OrdinalHyperparameter,
        UniformIntegerHyperparameter, UniformFloatHyperparameter,
        IntegerHyperparameter, FloatHyperparameter
    )

    for hp in cs.values():
        name = hp.name

        # Constant: never searched.
        if isinstance(hp, Constant):
            fixed[name] = hp.value
            continue

        # Categorical / ordinal: both enumerate a finite set of values. Optuna has
        # no ordered-categorical distribution, so ordinals are searched as plain
        # categoricals (the sampler does not exploit their ordering).
        if isinstance(hp, CategoricalHyperparameter):
            values = list(hp.choices)
        elif isinstance(hp, OrdinalHyperparameter):
            values = list(hp.sequence)
        else:
            values = None
        if values is not None:
            if len(values) == 1:
                fixed[name] = values[0]
            else:
                space[name] = CategoricalDistribution(choices=values)
            continue

        # `log_scale` is checked alongside ConfigSpace's `log` attribute defensively.
        log = bool(getattr(hp, "log", False) or getattr(hp, "log_scale", False))

        # Integer
        if isinstance(hp, (UniformIntegerHyperparameter, IntegerHyperparameter)):
            # Optuna's IntDistribution only accepts step == 1 when log=True, so any
            # quantization `q` is dropped for log-scaled integers.
            step = 1 if log else int(getattr(hp, "q", None) or 1)
            space[name] = IntDistribution(
                low=int(hp.lower), high=int(hp.upper), log=log, step=step
            )
            continue

        # Float
        if isinstance(hp, (UniformFloatHyperparameter, FloatHyperparameter)):
            # Optuna's FloatDistribution rejects any step when log=True, so the
            # quantization `q` is dropped for log-scaled floats.
            step = None if log else getattr(hp, "q", None)
            space[name] = FloatDistribution(
                low=float(hp.lower), high=float(hp.upper), log=log, step=step
            )
            continue

        raise NotImplementedError(f"Unsupported HP type in HyperSense space: {type(hp).__name__} ({name})")

    return space, fixed


def run_sq_hypersense(*,
                      seed: int,
                      bench: str,
                      cs: CS.ConfigurationSpace,
                      obj: Objective,
                      budget_n: int,
                      logger: ExperimentLogger,
                      method_name: str = "SQ-HyperSense",
                      sample_ratio: float = 0.6,
                      init_trials: int = 100,
                      verbose: bool = False,
                      quiet: bool = True,
                      top_k: int | None = None):
    """
    Run HyperSense's SequentialGroupingStrategy on ``bench`` and log every evaluation.

    The ConfigSpace space is converted to an Optuna-style space. Single-valued
    hyperparameters are held fixed and merged into each evaluated configuration.
    Every successful objective call is sent to ``logger`` together with the
    running best.

    Args:
        seed: Seed passed to the HyperSense pipeline and written to each log row.
        bench: Benchmark name written to each log row.
        cs: Search space of the benchmark.
        obj: Objective whose ``evaluate(config)`` returns ``(loss, sim_time)``.
        budget_n: Passed to ``pipeline.run`` as ``total_trials``.
        logger: Receives one dict per successful objective evaluation.
        method_name: Method label written to each log row.
        sample_ratio: Passed to ``pipeline.run``.
        init_trials: Passed to ``pipeline.run`` as ``initial_trials``.
        verbose: Passed to ``pipeline.run``.
        quiet: Passed to ``pipeline.run``.
        top_k: Passed to ``pipeline.run``. Defaults to a third of the number of
            searched hyperparameters, with a minimum of 1.
    """
    import warnings

    # Penalty loss returned for configurations that cannot be canonicalized
    # (the pipeline runs with mode="min", so larger is worse).
    failed_loss = 1e9

    # 1) Convert search space
    hs_space, fixed = _cs_to_hypersense_space(cs)
    dim = len(hs_space)
    if top_k is None:
        top_k = max(1, dim // 3)

    # 2) Default configuration.
    # NOTE(review): the sample_configuration() fallback draws from cs's own RNG,
    # which is not seeded from `seed` here.
    try:
        default_cfg = dict(cs.get_default_configuration())
    except Exception:
        default_cfg = dict(cs.sample_configuration())
    default_cfg.update(fixed)

    # 3) Objective function (with canonicalize)
    n_eval = 0
    best = float("inf")

    def nb301_objective(config: Dict[str, Any], dataset=None):
        # `dataset` is unused here; presumably it is part of the objective
        # signature HyperSense calls with.
        nonlocal n_eval, best
        merged = dict(config)
        merged.update(fixed)

        try:
            merged = _canonicalize(cs, merged)
        except Exception as err:
            # Return a poor value to avoid crashing the trial, but surface the
            # failure. These calls are neither counted in n_eval nor logged.
            warnings.warn(
                f"{method_name}: could not canonicalize config; "
                f"returning penalty loss {failed_loss:g}: {err!r}",
                RuntimeWarning,
                stacklevel=2,
            )
            return failed_loss

        t0 = time.perf_counter()
        loss, sim_t = obj.evaluate(merged)
        elapsed = time.perf_counter() - t0

        n_eval += 1
        if loss < best:
            best = loss

        # Scores are logged as 1 - loss (presumably loss = 1 - accuracy; confirm).
        logger.log(dict(
            seed=seed, method=method_name, bench=bench,
            n_eval=n_eval,
            sim_time=sim_t,
            elapsed_time=elapsed,
            best_score=1 - best,
            curr_score=1 - loss,
            config=merged,
        ))
        return loss

    # 4) Fake dataset (required by interface; this benchmark has no dataset)
    full_dataset = []

    # 5) Build pipeline.
    # NOTE(review): test_fn is the same logging objective, so any evaluations the
    # pipeline makes through test_fn also increment n_eval and are logged.
    # Confirm that is intended.
    pipeline = HyperSensePipeline(
        search_space=hs_space,
        full_dataset=full_dataset,
        objective_fn=nb301_objective,
        test_fn=nb301_objective,
        sampler_class=StratifiedSampler,
        initial_optimizer_class=OptunaOptimizer,
        whole_optimizer_class=OptunaOptimizer,
        importance_analyzer_class=NRReliefFAnalyzer,
        default_config=default_cfg,
        mode="min",
        seed=seed,
        strategy_class=SequentialGroupingStrategy,
    )

    # 6) Run pipeline
    pipeline.run(
        sample_ratio=sample_ratio,
        initial_trials=init_trials,
        total_trials=budget_n,
        top_k=top_k,
        verbose=verbose,
        quiet=quiet,
    )

